{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:19.639593Z",
     "start_time": "2025-09-02T12:25:19.407436Z"
    }
   },
   "source": [
     "import gym\n",
     "\n",
     "\n",
     "# Environment wrapper: adapts the new gym step/reset API back to the\n",
     "# old 4-tuple style and caps every episode at 200 steps.\n",
     "# NOTE(review): `gym` is unmaintained (see the stderr warning below);\n",
     "# `import gymnasium as gym` is the suggested drop-in replacement.\n",
     "class MyWrapper(gym.Wrapper):\n",
     "    def __init__(self):\n",
     "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
     "        super().__init__(env)\n",
     "        self.env = env\n",
     "        self.step_n = 0  # steps taken in the current episode\n",
     "\n",
     "    def reset(self):\n",
     "        # Old-style reset: return only the observation, drop the info dict.\n",
     "        state, _ = self.env.reset()\n",
     "        self.step_n = 0\n",
     "        return state\n",
     "\n",
     "    def step(self, action):\n",
     "        # Old-style step: collapse (terminated, truncated) into one `done`\n",
     "        # flag, and force the episode to end after 200 steps.\n",
     "        state, reward, terminated, truncated, info = self.env.step(action)\n",
     "        done = terminated or truncated\n",
     "        self.step_n += 1\n",
     "        if self.step_n >= 200:\n",
     "            done = True\n",
     "        return state, reward, done, info\n",
     "\n",
     "\n",
     "env = MyWrapper()\n",
     "\n",
     "env.reset()"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.\n",
      "Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.\n",
      "Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.\n",
      "See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([ 0.7041271, -0.710074 ,  0.5331804], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:22.068547Z",
     "start_time": "2025-09-02T12:25:19.670892Z"
    }
   },
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIIVJREFUeJzt3Q1wVPW9//HvbrIJhJDE8JAYkxTm6hUZHoqAkDpT75WUqKigONc6XIzI6EiRAXGYSqs4Op0Jg/9/rShEe70Vpyp4sYKaim0mINYSCESwEB7EFkkQkvDQPEKez39+P+/uP4sB8rTZ727er5njye45u+fk52Y//H7ne85xOY7jCAAACrmDvQMAAFwKIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUCtoIbVmzRoZMWKEDBgwQKZMmSJFRUXB2hUAgFJBCal3331Xli5dKs8++6x88cUXMn78eMnKypLKyspg7A4AQClXMC4wa3pOkydPlldeecU+bmtrk7S0NFm0aJE89dRTfb07AAClIvt6g01NTVJcXCzLly/3Ped2uyUzM1MKCws7fE1jY6OdvEyonTt3ToYMGSIul6tP9hsA0HtM/6i2tlZSUlJsBqgJqTNnzkhra6skJSX5PW8eHz58uMPX5OTkyHPPPddHewgA6CtlZWWSmpqqJ6S6w/S6zDEsr+rqaklPT7e/XFxcXFD3DQDQdTU1NfYwz+DBgy+7Xp+H1NChQyUiIkIqKir8njePk5OTO3xNdHS0nS5mAoqQAoDQdaVDNn1e3RcVFSUTJ06UgoICv2NM5nFGRkZf7w4AQLGgDPeZobvs7GyZNGmS3HTTTfKb3/xG6uvrZd68ecHYHQCAUkEJqfvvv19Onz4tK1askPLycvnhD38on3zyyfeKKQAA/VtQzpPqjQNu8fHxtoCCY1IAEHo6+z3OtfsAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAqEVIAQDUIqQAAGoRUgAAtQgpAIBahBQAQC1CCgCgFiEFAFCLkAIAhE9IffbZZ3LXXXdJSkqKuFwu2bx5s99yx3FkxYoVcvXVV8vAgQMlMzNTjh496rfOuXPnZM6cORIXFycJCQkyf/58qaur6/lvAwDo3yFVX18v48ePlzVr1nS4fNWqVbJ69Wp59dVXZdeuXTJo0CDJysqShoYG3zomoEpKSiQ/P1/y8vJs8D366KM9+00AAOHH6QHz8k2bNvket7W1OcnJyc4LL7zge66qqsqJjo521q9fbx8fPHjQvm737t2+dbZs2eK4XC7n22+/7dR2q6ur7XuYOQAg9HT2e7xXj0kdO3ZMysvL7RCfV3x8vEyZMkUKCwvtYzM3Q3yTJk3yrWPWd7vdtufVkcbGRqmpqfGbAADhr1dDygSUkZSU5Pe8eexdZubDhw/3Wx4ZGSmJiYm+dS6Wk5Njw847paWl9eZuAwCUConqvuXLl0t1dbVvKisrC/YuAQBCLaSSk5PtvKKiwu9589i7zMwrKyv9lre0tNiKP+86F4uOjraVgO0nAED469WQGjlypA2agoIC33Pm+JE51pSRkWEfm3lVVZUUFxf71tm6dau0tbXZY1cAAHhFSheZ85m+/vprv2KJffv22WNK6enpsmTJ
EvnVr34l1113nQ2tZ555xp5TNWvWLLv+DTfcILfddps88sgjtky9ublZHn/8cfnpT39q1wMAwKerZYPbtm2zZYMXT9nZ2b4y9GeeecZJSkqypefTpk1zjhw54vceZ8+edR544AEnNjbWiYuLc+bNm+fU1tb2eukiAECnzn6Pu8x/JMSYIURT5WeKKDg+BQChp7Pf4yFR3QcA6J8IKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWpHB3gEA/ZPT0iJtTU127rS1ibjd4o6IEFdUlLgiI8XlcgV7F6EAIQWgT5lQunDihNQUF0tVUZE0nDghrfX14h44UAakpEjchAmSMGWKDBwxQtweT7B3F0FGSAHoM6bHdHbbNqnMy5MLpaUira2+ZW319XL+6FE5//XXUrVjhwyZPl2G33GHuKOigrrPCC5CCkDAOY4jrefPyzcvvSS1+/ZJW0PD5Va2vauTb70lrbW1cvV//Ie4o6P7cnehCIUTAAIeUC3V1XJ89Wqp3rnz8gHV/nVNTVL54Ydy+pNPpK25OeD7CZ0IKQABDSinuVmOr1kjVYWFXX59W2OjVGzeLPVffWXfC/0PIQUgoD2of6xaJdW7dnX7fZrPnpWqnTttwQX6H0IKQEACqvmf/5Tjr7wi1UVFPX4/M+zXduFCr+wbQguFEwB6PaBMoJSagNqzJ9i7gxBHSAHo1YBqOn1aSnNz7XlQQE8x3Aeg15jjRwQUehM9KQC9UyRRUyPHX35Zavbu7f0NuN0iXCapXyKkAPQ4oBpPnpTS116zJ+oGQsqcORIRExOQ94ZuDPcB6BF7DCqAATUgNVXiJ0wQV0REQN4futGTAtD9MvMzZ+wQX+2XXwZkGxGxsZJ83332YrPonwgpAN0KqIbSUin97W+lbv/+gG0nfuJESfy3fxOXOSaFfon/8wC6rKmiIuABZcT8y78QUP0cPSkAXSuSOHXKXkmi7sCBgG3H3PQw5T//U4bfeWfAtoHQQEgB6HRAXTh2TMpMD+rgwYBtx+XxyDUPPWQDirvzgn40gE4xPahAB5Th7UERUDDoSQG44t10G8rK5HhurtQHuAdlA+quuwgo+BBSAC47xHf+73+Xsv/6L6k/fDigx6AY4kNHGO4DcOky8xMnAh5QBkN8uBR6UgA6HOKzPajf/lbqjxwJ2HZcUVFyDQGFyyCkAHyvB1V/9KgNqPNHjwZuQxERcs3cuZI0c2bgtoGQx3AfAP8y8+PHAx9QIjagTJEEcDn0pAD4hvjM0J45BnX+668Dth13dLS9qrkd4uNqErgCQgqA7UHVHT5se1AX/vGPwG3I7ZYUM8R3992B2wbCCv+MAfq59mXmAQ0ol0uuefBBLnWEwIVUTk6OTJ48WQYPHizDhw+XWbNmyZGLKn8aGhpk4cKFMmTIEImNjZXZs2dLRUWF3zqlpaUyY8YMiYmJse+zbNkyaWlp6dqeA+iVIb7a/fvt7TYu/P3vAR3iS334YVskwRAfuqJLn5bt27fbANq5c6fk5+dLc3OzTJ8+Xerr633rPPHEE/LRRx/Jxo0b7fonT56Ue++917e8tbXVBlRTU5Ps2LFD3nzzTVm3bp2sWLGiSzsOoOfMJY5OvP66vSZfwLhc9jwoG1DcuBBd5HJMX7+bTp8+bXtCJox+/OMfS3V1tQwbNkzeeecdue++++w6hw8flhtuuEEKCwtl6tSpsmXLFrnzzjtteCUlJdl1Xn31Vfn5z39u3y8qKuqK262pqZH4+Hi7vbi4uO7uPtCve1D2GNSrr8qFb74J3IbcbjvER0Chu9/jPep3mzc3EhMT7by4uNj2
rjIzM33rjBo1StLT021IGWY+duxYX0AZWVlZdodLSko63E5jY6Nd3n4C0P2Aqtm7V0rXrg1oQNkhvnnzJGnWLAIK3dbtkGpra5MlS5bIzTffLGPGjLHPlZeX255QQkKC37omkMwy7zrtA8q73LvsUsfCTOJ6p7S0tO7uNtDvVX74oZz47/+2d9YNJFNmzjEo9FS3Pz3m2NSBAwdkw4YNEmjLly+3vTbvVFZWFvBtAuHGaW2V8vfek2/festek68vLhYLBOU8qccff1zy8vLks88+k9TUVN/zycnJtiCiqqrKrzdlqvvMMu86RUVFfu/nrf7zrnOx6OhoOwHoWUCdfOcdU3Me2GvxmTJzbreBYPSkTI2FCahNmzbJ1q1bZeTIkX7LJ06cKB6PRwoKCnzPmRJ1U3KekZFhH5v5/v37pbKy0reOqRQ0B85Gjx7d898IwPdUbN4c8IAy7MViCSgEqydlhvhM5d4HH3xgz5XyHkMyx4kGDhxo5/Pnz5elS5faYgoTPIsWLbLBZCr7DFOybsJo7ty5smrVKvseTz/9tH1vektA72praZFKE1BmWD6QPSjvDQu5mjmCWYJ+qQ/fG2+8IQ899JDvZN4nn3xS1q9fb6vyTOXe2rVr/Ybyjh8/LgsWLJBPP/1UBg0aJNnZ2bJy5UqJjOxcZlKCDnSuiq/8f/7nux5UAJmASn3oIRlGQKELOvs93qPzpIKFkAKurPwPf5Bvf/97U4ob0O1c4y0zJ6AQgO9xLjALhJm25mZ7DOqUGeILYEBxw0L0BUIKCCNmYKRi0yY5+dZbAd2OOTnX3g/q7rsJKAQUZ9kBYRZQp9avD/i2zO02CCj0BXpSQDgN8a1fb8+JChR3VNR3NyykzBx9hJACQlxfDfH5blg4c2ZgtwO0w3AfEA5DfH1weTLvlSSAvkRPCgjlIb7/PQYV0CG+6Oj/f6IuF4tFHyOkgFC+1FGgh/ja3bAQCAZCCgjBK0n4zoMKpP+9YSFXM0cwEVJAiF2Lr+IPf5CTpsw8gCfq2iE+U2bOEB+CjJACQoi9WOzbbwd8O/aGhXffHfDtAFdCSAEhwBRG2GNQgR7i815Jgio+KEFIAcr1+Q0LGeKDIoQUoJw9UbePbljIEB+0IaQApRxTJPHBB4G/YWFkpO88KEAbQgrQOsRnqvgCXCRhblh4zUMPcbsNqMXAM6CQLZLoi6uZcz8oKEdPCtB2HpS51NG77wb2hoUez3cBxdXMoRwhBWi6kkRfDPFFRjLEh5DBcB+g5WrmfXEeFEN8CDH0pAAtNyw0ARXAq5nb86AIKIQYQgrQ0IP6/e/75EoSXM0coYbhPiCIvLd8DzQudYRQRU8KCOK5UK6ICHvPpoBezXzOHC51hJBFSAFBUv/VV3I6L0+c5ma/57+tr5e9585JbXOzDBswQDKGDZNBHk/XN+B229ttcKkjhDJCCghSsUT1nj3SWF7ud3zqWF2dPLt3r3xTVycNra0S5/HImKuukv8zebJ4utITcrm4YSHCAv1/IAiaz56Vivff93vuH3V18shf/yqHqqvlQmurmKv1VTc3y18rK2Xxrl1ytqGh00N8qQ8/bIskGOJDqOMTDASB6TWZY1Lt/aakxIZSR4rOnJH8kyev/MYulz0PygaUOd4FhDiG+4Bw4XYzxIewQ0gBYcBW8XmvxccQH8IIn2ZAiRlpaeK5RDn6iNhYGZeYeMnXmjJzjkEhHPGJBoLAEx8vif/+737PZaWkyLMTJsiAiAjfH2aEyyVDoqPl/06eLKMTEi57sVggHDHcBwSBe+BAuSojw5aht9bW2ufM9fRMUKXGxEjeiRO2ms/0oO4fOdIGVYfX4jPHoLjdBsIYIQUEgQmVuAkTZNjtt9vbc3gr/czz5rwoM12JvVgsAYUwx3AfEMRih6R77pHEW2+1w3advVmhveX7vHlczRz9Aj0pIIgiYmLsibeexEQ5t22bNFVWdrieOedpQFqaDM3KsnfsHUZAoZ8gpIAg
MkETOWiQXH3ffRI3bpz8c8cOqSspsZdLamtslIjYWBmQmirxkybZaWB6Oifpol8hpAAlQ3+xY8bIoH/9V2m9cEGclhZ7O3kTSG6PR9wxMeLu5JAgEE741AOKelWu6GgbWAC+Q+EEAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQBAeIRUbm6ujBs3TuLi4uyUkZEhW7Zs8S1vaGiQhQsXypAhQyQ2NlZmz54tFRUVfu9RWloqM2bMkJiYGBk+fLgsW7ZMWlpaeu83AgD0z5BKTU2VlStXSnFxsezZs0duvfVWmTlzppSUlNjlTzzxhHz00UeyceNG2b59u5w8eVLuvfde3+tbW1ttQDU1NcmOHTvkzTfflHXr1smKFSt6/zcDAIQ+p4euuuoq5/XXX3eqqqocj8fjbNy40bfs0KFDjtlEYWGhffzxxx87brfbKS8v962Tm5vrxMXFOY2NjZ3eZnV1tX1fMwcAhJ7Ofo93+5iU6RVt2LBB6uvr7bCf6V01NzdLZmamb51Ro0ZJenq6FBYW2sdmPnbsWElKSvKtk5WVJTU1Nb7eWEcaGxvtOu0nAED463JI7d+/3x5vio6Olscee0w2bdoko0ePlvLycomKipKEhAS/9U0gmWWGmbcPKO9y77JLycnJkfj4eN+UlpbW1d0GAPSHkLr++utl3759smvXLlmwYIFkZ2fLwYMHJZCWL18u1dXVvqmsrCyg2wMA6BDZ1ReY3tK1115rf544caLs3r1bXnrpJbn//vttQURVVZVfb8pU9yUnJ9ufzbyoqMjv/bzVf951OmJ6bWYCAPQvPT5Pqq2tzR4zMoHl8XikoKDAt+zIkSO25NwcszLM3AwXVlZW+tbJz8+35exmyBAAgG73pMyw2+23326LIWpra+Wdd96RTz/9VP70pz/ZY0Xz58+XpUuXSmJiog2eRYsW2WCaOnWqff306dNtGM2dO1dWrVplj0M9/fTT9twqekoAgB6FlOkBPfjgg3Lq1CkbSubEXhNQP/nJT+zyF198Udxutz2J1/SuTOXe2rVrfa+PiIiQvLw8eyzLhNegQYPsMa3nn3++K7sBAOgnXKYOXUKMKUE3IWmKKEyPDQAQnt/jXLsPAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAgFqEFABALUIKAKAWIQUAUIuQAgCoRUgBANQipAAAahFSAAC1CCkAQHiG1MqVK8XlcsmSJUt8zzU0NMjChQtlyJAhEhsbK7Nnz5aKigq/15WWlsqMGTMkJiZGhg8fLsuWLZOWlpae7AoAIAx1O6R2794tr732mowbN87v+SeeeEI++ugj2bhxo2zfvl1Onjwp9957r295a2urDaimpibZsWOHvPnmm7Ju3TpZsWJFz34TAED4cbqhtrbWue6665z8/HznlltucRYvXmyfr6qqcjwej7Nx40bfuocOHXLMZgoLC+3jjz/+2HG73U55eblvndzcXCcuLs5pbGzs1Parq6vte5o5ACD0dPZ7vFs9KTOcZ3pDmZmZfs8XFxdLc3Oz3/OjRo2S9PR0KSwstI/NfOzYsZKU
lORbJysrS2pqaqSkpKTD7TU2Ntrl7ScAQPiL7OoLNmzYIF988YUd7rtYeXm5REVFSUJCgt/zJpDMMu867QPKu9y7rCM5OTny3HPPdXVXAQAhrks9qbKyMlm8eLG8/fbbMmDAAOkry5cvl+rqat9k9gMAEP66FFJmOK+yslJuvPFGiYyMtJMpjli9erX92fSITEFEVVWV3+tMdV9ycrL92cwvrvbzPvauc7Ho6GiJi4vzmwAA4a9LITVt2jTZv3+/7Nu3zzdNmjRJ5syZ4/vZ4/FIQUGB7zVHjhyxJecZGRn2sZmb9zBh55Wfn2+DZ/To0b35uwEA+tMxqcGDB8uYMWP8nhs0aJA9J8r7/Pz582Xp0qWSmJhog2fRokU2mKZOnWqXT58+3YbR3LlzZdWqVfY41NNPP22LMUyPCQCAbhdOXMmLL74obrfbnsRrqvJM5d7atWt9yyMiIiQvL08WLFhgw8uEXHZ2tjz//PO9vSsAgBDnMnXoEmJMCXp8fLwtouD4FACEns5+j3PtPgCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAALUIKQCAWpESghzHsfOamppg7woAoBu839/e7/OwCqmzZ8/aeVpaWrB3BQDQA7W1tRIfHx9eIZWYmGjnpaWll/3l+jvzLxUT5GVlZRIXFxfs3VGLduoc2qlzaKfOMT0oE1ApKSmXXS8kQ8rt/u5QmgkoPgRXZtqIdroy2qlzaKfOoZ2urDOdDAonAABqEVIAALVCMqSio6Pl2WeftXNcGu3UObRT59BOnUM79S6Xc6X6PwAAgiQke1IAgP6BkAIAqEVIAQDUIqQAAGqFZEitWbNGRowYIQMGDJApU6ZIUVGR9CefffaZ3HXXXfZMbZfLJZs3b/ZbbmphVqxYIVdffbUMHDhQMjMz5ejRo37rnDt3TubMmWNPNkxISJD58+dLXV2dhIucnByZPHmyDB48WIYPHy6zZs2SI0eO+K3T0NAgCxculCFDhkhsbKzMnj1bKioq/NYxVzWZMWOGxMTE2PdZtmyZtLS0SLjIzc2VcePG+U48zcjIkC1btviW00YdW7lypf3bW7Jkie852ipAnBCzYcMGJyoqyvnd737nlJSUOI888oiTkJDgVFRUOP3Fxx9/7Pzyl7903n//fVOZ6WzatMlv+cqVK534+Hhn8+bNzpdffuncfffdzsiRI50LFy741rntttuc8ePHOzt37nT+8pe/ONdee63zwAMPOOEiKyvLeeONN5wDBw44+/btc+644w4nPT3dqaur863z2GOPOWlpaU5BQYGzZ88eZ+rUqc6PfvQj3/KWlhZnzJgxTmZmprN3717b7kOHDnWWL1/uhIsPP/zQ+eMf/+h89dVXzpEjR5xf/OIXjsfjse1m0EbfV1RU5IwYMcIZN26cs3jxYt/ztFVghFxI3XTTTc7ChQt9j1tbW52UlBQnJyfH6Y8uDqm2tjYnOTnZeeGFF3zPVVVVOdHR0c769evt44MHD9rX7d6927fOli1bHJfL5Xz77bdOOKqsrLS/8/bt231tYr6MN27c6Fvn0KFDdp3CwkL72HyJuN1up7y83LdObm6uExcX5zQ2Njrh6qqrrnJef/112qgDtbW1
znXXXefk5+c7t9xyiy+kaKvACanhvqamJikuLrbDV+2v42ceFxYWBnXftDh27JiUl5f7tZG5PpYZFvW2kZmbIb5Jkyb51jHrm7bctWuXhKPq6mq/ixObz1Fzc7NfO40aNUrS09P92mns2LGSlJTkWycrK8teQLSkpETCTWtrq2zYsEHq6+vtsB9t9H1mOM8M17VvE4O2CpyQusDsmTNn7B9S+//Jhnl8+PDhoO2XJiagjI7ayLvMzM14eHuRkZH2C9y7Tjhpa2uzxw5uvvlmGTNmjH3O/J5RUVE2rC/XTh21o3dZuNi/f78NJXNMxRxL2bRpk4wePVr27dtHG7VjAvyLL76Q3bt3f28Zn6fACamQArr7r98DBw7I559/HuxdUen666+3gWR6m++9955kZ2fL9u3bg71bqpjbbixevFjy8/NtwRb6TkgN9w0dOlQiIiK+VzFjHicnJwdtvzTxtsPl2sjMKysr/ZabCiNT8Rdu7fj4449LXl6ebNu2TVJTU33Pm9/TDB9XVVVdtp06akfvsnBhegDXXnutTJw40VZFjh8/Xl566SXa6KLhPPM3c+ONN9pRBzOZIF+9erX92fSIaKvAcIfaH5P5QyooKPAbyjGPzXAFREaOHGk/8O3byIx5m2NN3jYyc/PHZP7wvLZu3Wrb0hy7CgempsQElBm6Mr+baZf2zOfI4/H4tZMpUTclwu3byQyFtQ908y9pU6pthsPClfkcNDY20kbtTJs2zf6epsfpncwxXXMah/dn2ipAnBAsQTeVauvWrbNVao8++qgtQW9fMRPuTIWRKWE1k/lf+Otf/9r+fPz4cV8JummTDz74wPnb3/7mzJw5s8MS9AkTJji7du1yPv/8c1uxFE4l6AsWLLBl+J9++qlz6tQp33T+/Hm/kmFTlr5161ZbMpyRkWGni0uGp0+fbsvYP/nkE2fYsGFhVTL81FNP2YrHY8eO2c+KeWyqPP/85z/b5bTRpbWv7jNoq8AIuZAyXn75ZfthMOdLmZJ0c65Pf7Jt2zYbThdP2dnZvjL0Z555xklKSrKBPm3aNHsOTHtnz561oRQbG2tLYOfNm2fDL1x01D5mMudOeZnQ/tnPfmZLrmNiYpx77rnHBll733zzjXP77bc7AwcOtOe0PPnkk05zc7MTLh5++GHnBz/4gf1bMl+Y5rPiDSiDNup8SNFWgcGtOgAAaoXUMSkAQP9CSAEA1CKkAABqEVIAALUIKQCAWoQUAEAtQgoAoBYhBQBQi5ACAKhFSAEA1CKkAABqEVIAANHq/wFEloyMo1Jw6wAAAABJRU5ErkJggg=="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.615519Z",
     "start_time": "2025-09-02T12:25:22.187985Z"
    }
   },
   "source": [
     "import torch\n",
     "\n",
     "# Action-value model: maps a 3-dim Pendulum state to scores for 11\n",
     "# discretized actions. This is the network actually used to act.\n",
     "model = torch.nn.Sequential(\n",
     "    torch.nn.Linear(3, 128),\n",
     "    torch.nn.ReLU(),\n",
     "    torch.nn.Linear(128, 11),\n",
     ")\n",
     "\n",
     "# Experience network (same architecture): used to evaluate the score of\n",
     "# a state.\n",
     "next_model = torch.nn.Sequential(\n",
     "    torch.nn.Linear(3, 128),\n",
     "    torch.nn.ReLU(),\n",
     "    torch.nn.Linear(128, 11),\n",
     ")\n",
     "\n",
     "# Copy model's parameters into next_model so both start identical.\n",
     "next_model.load_state_dict(model.state_dict())\n",
     "\n",
     "model, next_model"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.634310Z",
     "start_time": "2025-09-02T12:25:23.622813Z"
    }
   },
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    #走神经网络,得到一个动作\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    action = model(state).argmax().item()\n",
    "\n",
    "    if random.random() < 0.01:\n",
    "        action = random.choice(range(11))\n",
    "\n",
    "    #离散动作连续化\n",
    "    action_continuous = action\n",
    "    action_continuous /= 10\n",
    "    action_continuous *= 4\n",
    "    action_continuous -= 2\n",
    "\n",
    "    return action, action_continuous\n",
    "\n",
    "\n",
    "get_action([0.29292667, 0.9561349, 1.0957013])"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, -1.2)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.653788Z",
     "start_time": "2025-09-02T12:25:23.643453Z"
    }
   },
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 5000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 5000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.684561Z",
     "start_time": "2025-09-02T12:25:23.668694Z"
    }
   },
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/s9/yht5_svd6mxft5fpm7ht48f80000gn/T/ipykernel_92433/3503280997.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:256.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.9959, -0.0903,  1.6575],\n",
       "         [-0.9748,  0.2232, -0.8847],\n",
       "         [-0.9518, -0.3068,  1.3767],\n",
       "         [-0.8172, -0.5764, -0.0760],\n",
       "         [-0.8899,  0.4561,  5.0242],\n",
       "         [-0.9596, -0.2812, -1.2232],\n",
       "         [-0.9772,  0.2125, -0.5790],\n",
       "         [-0.9648, -0.2630, -1.5703],\n",
       "         [-0.9971,  0.0755,  1.7725],\n",
       "         [-0.8277,  0.5611,  1.0430],\n",
       "         [-0.9938,  0.1112,  1.1848],\n",
       "         [-0.9904,  0.1380,  0.7081],\n",
       "         [ 0.1845,  0.9828,  1.0279],\n",
       "         [-0.9505, -0.3107,  1.1428],\n",
       "         [-0.8878,  0.4603, -2.2501],\n",
       "         [-0.9593, -0.2824, -1.8802],\n",
       "         [-0.8860, -0.4636, -0.4846],\n",
       "         [-0.7786, -0.6275,  3.7447],\n",
       "         [-0.8432,  0.5376, -1.7849],\n",
       "         [-0.8943, -0.4475,  0.8721],\n",
       "         [-0.9898,  0.1428, -1.0642],\n",
       "         [-0.9984,  0.0562, -1.5144],\n",
       "         [-0.5303, -0.8478,  1.3940],\n",
       "         [-0.9644,  0.2646,  0.0313],\n",
       "         [-0.8819, -0.4715,  1.1555],\n",
       "         [-0.7829,  0.6221, -0.1163],\n",
       "         [-0.8766, -0.4812, -0.6605],\n",
       "         [-0.9001, -0.4357, -0.7853],\n",
       "         [-0.9573,  0.2891,  0.7765],\n",
       "         [-0.5315, -0.8470, -0.4866],\n",
       "         [-0.9893,  0.1458,  0.8213],\n",
       "         [-0.9739,  0.2271, -0.2997],\n",
       "         [-0.7843, -0.6203, -2.3686],\n",
       "         [-0.7865,  0.6175, -0.6994],\n",
       "         [-0.9959,  0.0899, -1.2517],\n",
       "         [-0.8689, -0.4950,  0.2619],\n",
       "         [-0.9076, -0.4198,  1.2903],\n",
       "         [-0.5881, -0.8088,  2.2406],\n",
       "         [-0.8577, -0.5142,  0.9819],\n",
       "         [-0.8467, -0.5321,  0.1561],\n",
       "         [-0.9269, -0.3753, -2.9752],\n",
       "         [-0.9728,  0.2318, -2.9322],\n",
       "         [-0.7055, -0.7087, -1.9571],\n",
       "         [-0.9975, -0.0712,  1.3462],\n",
       "         [-0.9129, -0.4081,  0.9982],\n",
       "         [-0.9950, -0.1003, -1.6646],\n",
       "         [-0.9145, -0.4045, -1.2536],\n",
       "         [-0.8717, -0.4900, -0.9672],\n",
       "         [-0.8282, -0.5604, -0.3882],\n",
       "         [-0.9384,  0.3455,  0.0328],\n",
       "         [-0.9990, -0.0439,  1.6574],\n",
       "         [-0.9942, -0.1076,  1.2277],\n",
       "         [-0.9996,  0.0287,  1.3009],\n",
       "         [-0.9994,  0.0358,  1.1204],\n",
       "         [-0.9322, -0.3620,  1.0897],\n",
       "         [-0.8660, -0.5000,  0.4771],\n",
       "         [-0.9765,  0.2156,  5.1262],\n",
       "         [-0.7486,  0.6630,  4.7669],\n",
       "         [-0.9938,  0.1116, -1.0492],\n",
       "         [-0.9855,  0.1698,  0.9374],\n",
       "         [-0.9741, -0.2259, -3.1367],\n",
       "         [-0.9976, -0.0685, -3.1862],\n",
       "         [-0.9554, -0.2952,  1.4219],\n",
       "         [-0.9825,  0.1863,  0.5616]]),\n",
       " tensor([[1],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [1],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [1],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [2],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [1],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [0],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [1],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [8],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [1],\n",
       "         [8],\n",
       "         [8],\n",
       "         [8],\n",
       "         [8],\n",
       "         [7],\n",
       "         [1],\n",
       "         [1],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [7],\n",
       "         [8],\n",
       "         [7]]),\n",
       " tensor([[ -9.5867],\n",
       "         [ -8.5851],\n",
       "         [ -8.1987],\n",
       "         [ -6.3885],\n",
       "         [ -9.6451],\n",
       "         [ -8.3101],\n",
       "         [ -8.6045],\n",
       "         [ -8.5158],\n",
       "         [ -9.7174],\n",
       "         [ -6.5906],\n",
       "         [ -9.3239],\n",
       "         [ -9.0696],\n",
       "         [ -2.0260],\n",
       "         [ -8.1162],\n",
       "         [ -7.5999],\n",
       "         [ -8.5072],\n",
       "         [ -7.0972],\n",
       "         [ -7.4723],\n",
       "         [ -6.9448],\n",
       "         [ -7.2472],\n",
       "         [ -9.1038],\n",
       "         [ -9.7494],\n",
       "         [ -4.7343],\n",
       "         [ -8.2596],\n",
       "         [ -7.1609],\n",
       "         [ -6.1037],\n",
       "         [ -7.0117],\n",
       "         [ -7.3026],\n",
       "         [ -8.1735],\n",
       "         [ -4.5664],\n",
       "         [ -9.0397],\n",
       "         [ -8.4924],\n",
       "         [ -6.6744],\n",
       "         [ -6.1800],\n",
       "         [ -9.4691],\n",
       "         [ -6.8917],\n",
       "         [ -7.5035],\n",
       "         [ -5.3423],\n",
       "         [ -6.8659],\n",
       "         [ -6.6621],\n",
       "         [ -8.4861],\n",
       "         [ -9.3151],\n",
       "         [ -5.9249],\n",
       "         [ -9.6098],\n",
       "         [ -7.5062],\n",
       "         [ -9.5261],\n",
       "         [ -7.5842],\n",
       "         [ -7.0085],\n",
       "         [ -6.5015],\n",
       "         [ -7.7784],\n",
       "         [ -9.8731],\n",
       "         [ -9.3563],\n",
       "         [ -9.8605],\n",
       "         [ -9.7726],\n",
       "         [ -7.7994],\n",
       "         [ -6.8770],\n",
       "         [-11.1821],\n",
       "         [ -8.1159],\n",
       "         [ -9.2901],\n",
       "         [ -8.9151],\n",
       "         [ -9.4742],\n",
       "         [-10.4592],\n",
       "         [ -8.2804],\n",
       "         [ -8.7595]]),\n",
       " tensor([[-9.8755e-01, -1.5729e-01,  1.3498e+00],\n",
       "         [-9.6768e-01,  2.5218e-01, -5.9736e-01],\n",
       "         [-9.2935e-01, -3.6919e-01,  1.3266e+00],\n",
       "         [-8.2822e-01, -5.6040e-01, -3.8823e-01],\n",
       "         [-9.7649e-01,  2.1556e-01,  5.1262e+00],\n",
       "         [-9.7604e-01, -2.1759e-01, -1.3141e+00],\n",
       "         [-9.7388e-01,  2.2707e-01, -2.9970e-01],\n",
       "         [-9.8317e-01, -1.8268e-01, -1.6476e+00],\n",
       "         [-9.9999e-01, -3.9154e-03,  1.5891e+00],\n",
       "         [-8.6951e-01,  4.9391e-01,  1.5838e+00],\n",
       "         [-9.9924e-01,  3.8985e-02,  1.4481e+00],\n",
       "         [-9.9578e-01,  9.1747e-02,  9.3161e-01],\n",
       "         [ 1.0612e-01,  9.9435e-01,  1.5850e+00],\n",
       "         [-9.3216e-01, -3.6204e-01,  1.0897e+00],\n",
       "         [-8.4321e-01,  5.3758e-01, -1.7849e+00],\n",
       "         [-9.8245e-01, -1.8655e-01, -1.9720e+00],\n",
       "         [-9.0199e-01, -4.3177e-01, -7.1235e-01],\n",
       "         [-6.7482e-01, -7.3798e-01,  3.0341e+00],\n",
       "         [-8.0765e-01,  5.8967e-01, -1.2617e+00],\n",
       "         [-8.7768e-01, -4.7924e-01,  7.1650e-01],\n",
       "         [-9.8291e-01,  1.8408e-01, -8.3710e-01],\n",
       "         [-9.9234e-01,  1.2353e-01, -1.3523e+00],\n",
       "         [-5.1077e-01, -8.5972e-01,  4.5818e-01],\n",
       "         [-9.6884e-01,  2.4768e-01,  3.4979e-01],\n",
       "         [-8.5769e-01, -5.1417e-01,  9.8186e-01],\n",
       "         [-7.9734e-01,  6.0353e-01,  4.7032e-01],\n",
       "         [-8.9741e-01, -4.4120e-01, -9.0142e-01],\n",
       "         [-9.2059e-01, -3.9053e-01, -9.9204e-01],\n",
       "         [-9.7189e-01,  2.3543e-01,  1.1134e+00],\n",
       "         [-5.7328e-01, -8.1936e-01, -1.0019e+00],\n",
       "         [-9.9560e-01,  9.3658e-02,  1.0507e+00],\n",
       "         [-9.7377e-01,  2.2753e-01, -9.3933e-03],\n",
       "         [-8.6103e-01, -5.0855e-01, -2.7138e+00],\n",
       "         [-7.8293e-01,  6.2211e-01, -1.1627e-01],\n",
       "         [-9.8975e-01,  1.4279e-01, -1.0642e+00],\n",
       "         [-8.6864e-01, -4.9544e-01,  1.0653e-02],\n",
       "         [-8.8189e-01, -4.7146e-01,  1.1555e+00],\n",
       "         [-5.3033e-01, -8.4779e-01,  1.3940e+00],\n",
       "         [-8.3709e-01, -5.4706e-01,  7.7623e-01],\n",
       "         [-8.4993e-01, -5.2689e-01, -1.2297e-01],\n",
       "         [-9.7414e-01, -2.2593e-01, -3.1367e+00],\n",
       "         [-9.3383e-01,  3.5771e-01, -2.6384e+00],\n",
       "         [-7.8433e-01, -6.2035e-01, -2.3686e+00],\n",
       "         [-9.8953e-01, -1.4436e-01,  1.4728e+00],\n",
       "         [-8.9428e-01, -4.4750e-01,  8.7212e-01],\n",
       "         [-9.9981e-01, -1.9487e-02, -1.6198e+00],\n",
       "         [-9.4121e-01, -3.3782e-01, -1.4370e+00],\n",
       "         [-8.9986e-01, -4.3618e-01, -1.2147e+00],\n",
       "         [-8.4702e-01, -5.3156e-01, -6.8853e-01],\n",
       "         [-9.4534e-01,  3.2608e-01,  4.1195e-01],\n",
       "         [-9.9361e-01, -1.1286e-01,  1.3845e+00],\n",
       "         [-9.8488e-01, -1.7324e-01,  1.3271e+00],\n",
       "         [-9.9892e-01, -4.6360e-02,  1.5025e+00],\n",
       "         [-9.9953e-01, -3.0506e-02,  1.3273e+00],\n",
       "         [-9.1294e-01, -4.0809e-01,  9.9819e-01],\n",
       "         [-8.6040e-01, -5.0962e-01,  2.2203e-01],\n",
       "         [-9.9938e-01, -3.5123e-02,  5.0479e+00],\n",
       "         [-8.8994e-01,  4.5607e-01,  5.0242e+00],\n",
       "         [-9.8815e-01,  1.5351e-01, -8.4545e-01],\n",
       "         [-9.9380e-01,  1.1117e-01,  1.1848e+00],\n",
       "         [-9.9765e-01, -6.8536e-02, -3.1862e+00],\n",
       "         [-9.9619e-01,  8.7179e-02, -3.1176e+00],\n",
       "         [-9.3281e-01, -3.6037e-01,  1.3805e+00],\n",
       "         [-9.8931e-01,  1.4581e-01,  8.2134e-01]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.776416Z",
     "start_time": "2025-09-02T12:25:23.769104Z"
    }
   },
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4452],\n",
       "        [0.7655],\n",
       "        [0.3563],\n",
       "        [0.4277],\n",
       "        [1.3802],\n",
       "        [0.8441],\n",
       "        [0.6558],\n",
       "        [0.9688],\n",
       "        [0.4921],\n",
       "        [0.3655],\n",
       "        [0.3429],\n",
       "        [0.3593],\n",
       "        [0.5032],\n",
       "        [0.3133],\n",
       "        [1.3086],\n",
       "        [1.0815],\n",
       "        [0.5778],\n",
       "        [1.0815],\n",
       "        [1.1353],\n",
       "        [0.2510],\n",
       "        [0.8216],\n",
       "        [0.9765],\n",
       "        [0.4455],\n",
       "        [0.5204],\n",
       "        [0.3090],\n",
       "        [0.6032],\n",
       "        [0.6277],\n",
       "        [0.6701],\n",
       "        [0.3614],\n",
       "        [0.4909],\n",
       "        [0.3348],\n",
       "        [0.5762],\n",
       "        [1.2183],\n",
       "        [0.7290],\n",
       "        [0.8806],\n",
       "        [0.3384],\n",
       "        [0.3366],\n",
       "        [0.6214],\n",
       "        [0.2704],\n",
       "        [0.3625],\n",
       "        [1.4912],\n",
       "        [1.5354],\n",
       "        [1.0413],\n",
       "        [0.3641],\n",
       "        [0.2802],\n",
       "        [1.0157],\n",
       "        [0.8438],\n",
       "        [0.7264],\n",
       "        [0.5283],\n",
       "        [0.5289],\n",
       "        [0.4493],\n",
       "        [0.3368],\n",
       "        [0.3624],\n",
       "        [0.3209],\n",
       "        [0.3006],\n",
       "        [0.2852],\n",
       "        [1.4284],\n",
       "        [1.2838],\n",
       "        [0.8125],\n",
       "        [0.3158],\n",
       "        [1.5677],\n",
       "        [1.6029],\n",
       "        [0.3647],\n",
       "        [0.3981]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:23.870932Z",
     "start_time": "2025-09-02T12:25:23.865098Z"
    }
   },
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "    \"\"\"以下是主要的Double DQN和DQN的区别\"\"\"\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b]\n",
    "    #target = target.max(dim=1)[0]\n",
    "\n",
    "    #使用model计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        model_target = model(next_state)\n",
    "\n",
    "    #取分数最高的下标\n",
    "    #[b, 11] -> [b, 1]\n",
    "    model_target = model_target.max(dim=1)[1]\n",
    "    model_target = model_target.reshape(-1, 1)\n",
    "\n",
    "    #以这个下标取next_value当中的值\n",
    "    #[b, 11] -> [b]\n",
    "    target = target.gather(dim=1, index=model_target)\n",
    "    \"\"\"以上是主要的Double DQN和DQN的区别\"\"\"\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target *= 0.98\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-9.2348],\n",
       "        [-7.9327],\n",
       "        [-7.8607],\n",
       "        [-5.8708],\n",
       "        [-8.2452],\n",
       "        [-7.4461],\n",
       "        [-8.0398],\n",
       "        [-7.5335],\n",
       "        [-9.2947],\n",
       "        [-6.1350],\n",
       "        [-8.9378],\n",
       "        [-8.7681],\n",
       "        [-1.5110],\n",
       "        [-7.8216],\n",
       "        [-6.4874],\n",
       "        [-7.4058],\n",
       "        [-6.4616],\n",
       "        [-6.6287],\n",
       "        [-6.0290],\n",
       "        [-7.0113],\n",
       "        [-8.3756],\n",
       "        [-8.8455],\n",
       "        [-4.4955],\n",
       "        [-7.8165],\n",
       "        [-6.8958],\n",
       "        [-5.6564],\n",
       "        [-6.3168],\n",
       "        [-6.5693],\n",
       "        [-7.8482],\n",
       "        [-3.9116],\n",
       "        [-8.7380],\n",
       "        [-7.9791],\n",
       "        [-5.3290],\n",
       "        [-5.5888],\n",
       "        [-8.6640],\n",
       "        [-6.4832],\n",
       "        [-7.2006],\n",
       "        [-4.9057],\n",
       "        [-6.6464],\n",
       "        [-6.2188],\n",
       "        [-6.9497],\n",
       "        [-7.9030],\n",
       "        [-4.7309],\n",
       "        [-9.2326],\n",
       "        [-7.2602],\n",
       "        [-8.5392],\n",
       "        [-6.6837],\n",
       "        [-6.1992],\n",
       "        [-5.8838],\n",
       "        [-7.3410],\n",
       "        [-9.5113],\n",
       "        [-9.0096],\n",
       "        [-9.4668],\n",
       "        [-9.4167],\n",
       "        [-7.5248],\n",
       "        [-6.5373],\n",
       "        [-9.7865],\n",
       "        [-6.7633],\n",
       "        [-8.5627],\n",
       "        [-8.5790],\n",
       "        [-7.9034],\n",
       "        [-8.8995],\n",
       "        [-7.9333],\n",
       "        [-8.4314]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:24.222798Z",
     "start_time": "2025-09-02T12:25:24.213401Z"
    }
   },
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        _, action_continuous = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1655.3153321247662"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false,
    "ExecuteTime": {
     "end_time": "2025-09-02T12:25:42.959438Z",
     "start_time": "2025-09-02T12:25:24.300220Z"
    }
   },
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 -1432.4926927842182\n",
      "20 4400 200 0 -1302.1442762119111\n",
      "40 5000 200 200 -1020.5148368226135\n",
      "60 5000 200 200 -326.0384319854087\n",
      "80 5000 200 200 -873.5121309042464\n",
      "100 5000 200 200 -191.83464328705722\n",
      "120 5000 200 200 -262.1552549377575\n",
      "140 5000 200 200 -172.038112852835\n",
      "160 5000 200 200 -726.2437875001659\n",
      "180 5000 200 200 -679.3539137821294\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "test(play=True)"
   ],
   "execution_count": 11,
   "outputs": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
